Cpaniaguam/rlssm simplified interface#955
Pull request overview
This PR introduces a simplified public interface for RLSSM by adding a named-model registry and wrapping the existing RLSSM implementation behind a friendlier constructor that can build an RLSSMConfig from a model name (defaulting to "rldm").
Changes:
- Split the existing RLSSM implementation into an internal base class (_RLSSM) and a public wrapper (RLSSM) with a simplified constructor.
- Added an RLSSM/SSM registry and factory (get_rlssm_model_config, register_rlssm_model, register_ssm) to construct configs from named models.
- Expanded the RLSSM test suite to cover the simplified interface, registry behavior, and the new wrapper semantics.
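For readers unfamiliar with the pattern, the named-model registry and factory described above can be sketched roughly as follows. This is a minimal illustration, not the PR's code: `NamedModelSpec`, `_NAMED_MODELS`, `register_named_model`, and `build_config` are hypothetical names, and the field values for "rldm" are placeholders. Only the default model name "rldm" comes from the PR description.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class NamedModelSpec:
    """Hypothetical stand-in for whatever a registry entry stores."""

    name: str
    decision_process: str
    rl_params: tuple = ()


# Hypothetical module-level registry: model name -> spec.
_NAMED_MODELS = {}


def register_named_model(spec: NamedModelSpec) -> None:
    """Add a spec under its name, refusing silent overwrites."""
    if spec.name in _NAMED_MODELS:
        raise ValueError(f"Model {spec.name!r} is already registered.")
    _NAMED_MODELS[spec.name] = spec


def build_config(model: str = "rldm") -> NamedModelSpec:
    """Factory: resolve a model name to its spec, with a helpful error."""
    try:
        return _NAMED_MODELS[model]
    except KeyError:
        known = ", ".join(sorted(_NAMED_MODELS)) or "<none>"
        raise KeyError(f"Unknown model {model!r}. Known models: {known}") from None


# Illustrative values only; the real "rldm" entry may differ.
register_named_model(
    NamedModelSpec(name="rldm", decision_process="ddm", rl_params=("alpha",))
)
default_config = build_config()  # no name given, so "rldm" is used
```

The point of the sketch is the shape of the interface: callers pass a string, and the factory owns the lookup and the error message.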
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| tests/test_rlssm.py | Adds coverage for the new simplified RLSSM constructor and registry-based config creation. |
| src/hssm/rl/rlssm.py | Renames the prior implementation to _RLSSM and adds the public RLSSM wrapper + blocked-attribute behavior. |
| src/hssm/rl/registry.py | New registry/factory module for named RLSSM models and SSM base logp functions. |
| src/hssm/rl/__init__.py | Exposes _RLSSM and registry helpers in the hssm.rl public API. |
| src/hssm/__init__.py | Exposes register_rlssm_model at the top-level hssm API. |
- Added KeyError for missing SSM in the registry.
- Improved ValueError messages for non-callable logp functions.
- Implemented runtime checks for list_params and loglik in RLSSM.
- Ensured defensive copying of mutable parameters in registry functions.
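To make the intent of these checks concrete, here is a small sketch of what validation plus defensive copying in a registry can look like. It is an assumption-laden toy: `_SSM_TABLE`, `register_entry`, and `get_entry` are hypothetical names and do not mirror the actual functions in src/hssm/rl/registry.py.

```python
# Hypothetical registry table: SSM name -> {"logp": callable, "list_params": list}.
_SSM_TABLE = {}


def register_entry(name, logp, list_params):
    """Store an entry after validating it; copy the mutable list on the way in."""
    if not callable(logp):
        raise ValueError(
            f"logp for {name!r} must be callable, got {type(logp).__name__}."
        )
    # list(...) makes a new list, so later edits by the caller do not leak in.
    _SSM_TABLE[name] = {"logp": logp, "list_params": list(list_params)}


def get_entry(name):
    """Return a copy of the entry; raise KeyError if the name is unknown."""
    if name not in _SSM_TABLE:
        raise KeyError(f"SSM {name!r} is not registered.")
    entry = _SSM_TABLE[name]
    # Copy on the way out too, so callers cannot mutate registry state.
    return {"logp": entry["logp"], "list_params": list(entry["list_params"])}


params = ["v", "a"]
register_entry("toy_ssm", lambda data: 0.0, params)
params.append("z")  # caller mutates its own list after registering
fetched = get_entry("toy_ssm")
fetched["list_params"].append("t")  # caller mutates the returned copy
```

After both mutations the stored entry still holds only the two parameters it was registered with, which is the behavior the commit message describes.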
…rs for RL parameters, bounds, and defaults
…y .computed attribute
…l registration process
AlexanderFengler left a comment
Left a few comments, mostly looks good already.
@krishnbera can you take a look too?
- ``RLSSMConfig``: the config class for RL + SSM models in :mod:`hssm.rl.config`.
- ``get_rlssm_model_config``: factory that builds a config from a named model.
- ``register_rlssm_model``: register a custom named RLSSM model.
- ``register_ssm``: register a custom SSM base logp function.
would we want/need a register_learning_process equivalent here? ( @krishnbera )
yes we should have a way of registering the learning process separately.
…SSM model registration
…od function handling
flowchart TD
USER["User: RLSSM(model='rldm')"]
RLSSM_REG["_RLSSM_REGISTRY\nnamed models\ne.g. 'rldm'"]
SSM_REG["_SSM_REGISTRY\ncustom SSMs only\n(empty by default)"]
MODELCONFIG["hssm.modelconfig\nbuilt-in SSMs\nddm · angle · weibull · ornstein · …"]
CACHE["_SSM_LOGP_CACHE\nlazy ONNX → JAX fn"]
OUTPUT["RLSSMConfig"]
USER --> RLSSM_REG
RLSSM_REG -- decision_process name --> LOOKUP
LOOKUP{"registered\ncustom SSM?"}
LOOKUP -- yes --> SSM_REG
LOOKUP -- no --> MODELCONFIG
SSM_REG & MODELCONFIG --> CACHE
CACHE -- annotated JAX fn --> OUTPUT
RLSSM_REG -- rl params / bounds / learning process --> OUTPUT
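The lookup order in the flowchart above (custom SSM registry first, built-in model configs second, with a lazily filled cache in front of both) can be sketched in plain Python. Everything here is a toy under stated assumptions: `CUSTOM_SSMS`, `BUILTIN_BUILDERS`, `_LOGP_CACHE`, and `get_logp` are hypothetical names, and the "ddm" builder returns a placeholder function rather than an ONNX-derived JAX function.

```python
# Custom SSMs only; empty by default, as in the flowchart.
CUSTOM_SSMS = {}

# Counts how often the expensive build step runs, to make laziness visible.
build_calls = {"ddm": 0}


def _build_toy_ddm():
    """Placeholder for an expensive build (e.g. loading a network)."""
    build_calls["ddm"] += 1
    return lambda x: -abs(x)


# Built-in SSMs are stored as zero-argument builders so nothing is built eagerly.
BUILTIN_BUILDERS = {"ddm": _build_toy_ddm}

# Cache of resolved logp functions, filled on first use.
_LOGP_CACHE = {}


def get_logp(name):
    """Resolve a decision-process name to a logp function, building at most once."""
    if name in _LOGP_CACHE:
        return _LOGP_CACHE[name]
    if name in CUSTOM_SSMS:  # registered custom SSM takes precedence
        fn = CUSTOM_SSMS[name]
    elif name in BUILTIN_BUILDERS:  # otherwise fall back to built-ins
        fn = BUILTIN_BUILDERS[name]()
    else:
        raise KeyError(f"Unknown SSM {name!r}.")
    _LOGP_CACHE[name] = fn
    return fn


first = get_logp("ddm")
second = get_logp("ddm")  # served from the cache; the builder is not called again
```

One design consequence worth noting: because the cache is keyed by name, re-registering a custom SSM under an existing name would need to invalidate the cached entry.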
…_models_config_structure
… organizing test cases
krishnbera left a comment
looks good overall. added minor comments.
Examples
--------
>>> import hssm
>>> hssm.rl.list_models()
digicosmos86 left a comment
Generally it looks good. The use of _RLSSM as a separate class is an interesting choice, but unless someone is aware of how this separation came into being, it's difficult to tease out what function each class serves from the names alone. It might make it harder to find where things are in the future, and it creates separate sets of documentation that can easily get out of sync. I would merge these two classes.
__all__ = [
    "RLSSM",
    "_RLSSM",
The name "_RLSSM" suggests that this class is "private" and should not directly be accessed by the user, so it should not be exposed in __all__
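To illustrate why listing an underscore-prefixed name in `__all__` matters, the sketch below builds a throwaway module in memory and checks what a star import picks up. The module name `demo_rl` and the helper `star_import` are hypothetical and exist only for this demonstration. The behavior shown is standard Python: without `__all__`, a star import skips names starting with an underscore, while an explicit `__all__` exports exactly what it lists.

```python
import sys
import types


def star_import(all_names):
    """Return the names that `from demo_rl import *` would bind."""
    mod = types.ModuleType("demo_rl")
    mod.RLSSM = type("RLSSM", (), {})
    mod._RLSSM = type("_RLSSM", (), {})
    if all_names is not None:
        mod.__all__ = all_names
    sys.modules["demo_rl"] = mod  # make the fake module importable
    namespace = {}
    try:
        exec("from demo_rl import *", namespace)
    finally:
        del sys.modules["demo_rl"]  # clean up the fake module
    # exec() adds __builtins__ to the namespace; ignore it.
    return {name for name in namespace if name != "__builtins__"}


without_all = star_import(None)
with_private = star_import(["RLSSM", "_RLSSM"])
public_only = star_import(["RLSSM"])
```

Here `with_private` contains `_RLSSM` while the other two do not, so dropping the private name from `__all__` restores the default convention.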
class RLSSM(_RLSSM):
    """Reinforcement Learning Sequential Sampling Model — simplified public API.
This class is public-facing, so the phrase "simplified public API" may confuse users: they don't know what the "not simplified" API is.
    This class wraps :class:`_RLSSM` with a user-friendly constructor that
    accepts a *model* name string (looked up in the named-model registry) and
    optional overrides for *learning_process*, *decision_process*, and
    *choices*. Advanced users can bypass the registry entirely by supplying a
    pre-built *model_config*.
Again, this is public-facing API documentation, and this does not mean anything to the user if they are not already familiar with _RLSSM
        self.__dict__["_rlssm_fully_initialized"] = True

    @classproperty
    def list_models(cls) -> dict[str, str | None]:
This function has a different name than its counterpart in HSSM. It also doesn't really depend on the class, so it doesn't need to be a classproperty. How about making its counterpart in HSSM and this one plain functions accessible through hssm and hssm.rl respectively?
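A rough sketch of the module-level alternative suggested here, assuming the listing only needs to read registry state. `_MODEL_DESCRIPTIONS` is a hypothetical stand-in for the real registry, and the entry shown is a placeholder; the return type follows the annotation in the diff above.

```python
from typing import Dict, Optional

# Hypothetical registry state: model name -> optional description.
_MODEL_DESCRIPTIONS: Dict[str, Optional[str]] = {"rldm": None}


def list_models() -> Dict[str, Optional[str]]:
    """Return the registered model names and descriptions.

    A plain function needs no class, so it can be exposed at module level.
    It returns a copy so callers cannot alter the registry by accident.
    """
    return dict(_MODEL_DESCRIPTIONS)


listing = list_models()
listing["scratch"] = "local edit only"  # does not touch the registry
```

A plain function is also called with parentheses, which matches the `hssm.rl.list_models()` usage shown in the docstring example earlier in this thread.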
flowchart TD
subgraph SSM_Registry
A1["_SSM_REGISTRY (dict)"]
A2["_SSM_LOGP_CACHE (dict)"]
end
subgraph RLSSM_Registry
B1["_RLSSM_REGISTRY (dict)"]
end
subgraph User_API
C1[register_ssm]
C2[register_rlssm_model]
C3[get_rlssm_model_config]
end
C1-->|adds entry|A1
C1-->|adds entry|A2
C2-->|adds entry|B1
C3-->|reads entry|B1
C3-->|reads entry|A1
C3-->|calls _get_ssm_logp|A2
A2-->|lazy build if needed|A1
C3-->|returns RLSSMConfig|D1["RLSSMConfig"]
style D1 fill:#1e7a1e,stroke:#fff,stroke-width:2px,color:#fff
style A1 fill:#1e3a7a,stroke:#fff,stroke-width:2px,color:#fff
style A2 fill:#1e3a7a,stroke:#fff,stroke-width:2px,color:#fff
style B1 fill:#7a1e1e,stroke:#fff,stroke-width:2px,color:#fff
style C1 fill:#7a7a1e,stroke:#fff,stroke-width:2px,color:#fff
style C2 fill:#7a7a1e,stroke:#fff,stroke-width:2px,color:#fff
style C3 fill:#7a7a1e,stroke:#fff,stroke-width:2px,color:#fff